#include "alt.hpp"
#include "probs.hpp"

#include <algorithm>
#include <cmath>
#include <vector>

#include <omp.h>

/*void probs(double* out,const double*t,const double* pars,const int *n)
{
  const double* lam=pars+2;

  for(int i=0;i<*n;++i)
  {
//    std::cerr<<pars[0]<<"#"<<pars[1]<<"\n";

    double pf=1,ps=0;
    if((*t)>0)
    {
      pf=std::exp(-std::pow(((*t)*lam[i]),pars[0]));
      ps=(1-pf)*(1-pars[1]);
      pf=1-ps;
    }
    for(int j=0;j<ustate::maxel+1;++j)
    {
//      std::cerr<<i*(ustate::maxel+1)+j<<":"<<gsl_sf_choose(ustate::maxel,j)*std::pow(ps,double(j))*std::pow(pf,ustate::maxel-double(j));
      out[i*(ustate::maxel+1)+j]=gsl_sf_choose(ustate::maxel,j)*std::pow(ps,double(j))*std::pow(pf,ustate::maxel-double(j));
    }
  }
}*/
// Fill one output row per entry i in [0, *n) with the probabilities of
// 0, 1 and 2 events by time *t.
//
// Parameters:
//   out  - output buffer. Row i starts at out[i*(ustate::maxel+1)].
//          Entries 0, 1 and 2 receive P(0), P(1) and P(2). Any further
//          entries in the row (when ustate::maxel > 2) are set to 0.
//          NOTE(review): assumes ustate::maxel >= 2, otherwise entry 2
//          spills into the next row -- confirm.
//   t    - pointer to a single time value, shared by all rows. For *t <= 0
//          the row is (1, 0, 0).
//   pars - pars[0] = k, pars[1] = e, pars[2 + i] = rate l of row i.
//          Must hold at least 2 + *n values.
//   n    - pointer to the number of rows.
//
// With lt = t*l, the probabilities before the e-correction are
//   P0 = exp(-lt)
//   P1 = (exp(-lt) - exp(-k*lt)) / (k - 1)   for k != 1
//   P1 = lt * exp(-lt)                       for k == 1 (the limit)
//   P2 = 1 - P0 - P1
// These match a process whose first event has rate l and second event
// rate k*l (presumably the intended model -- confirm). Afterwards a
// fraction e of the mass in P1 and P2 is moved to P0.
void probs(double* out,const double*t,const std::vector<double>& pars,const int *n)
{
  const double* lam=pars.data()+2;
  const auto stride=ustate::maxel+1;

  for(int i=0;i<*n;++i)
  {
    const double k=pars[0];
    const double e=pars[1];

    double p0=1,p1=0,p2=0;

    if((*t)>0)
    {
      const double lt=(*t)*lam[i];
      p0=std::exp(-lt);

      // P1 is rewritten as lt*exp(-min(k,1)*lt)*(1-e^{-x})/x with
      // x=|k-1|*lt. This is algebraically equal to the k != 1 formula,
      // but avoids dividing the cancelling difference
      // exp(-lt)-exp(-k*lt) by the tiny (k-1) when k is near 1.
      // (1-e^{-x})/x -> 1 as x -> 0, which recovers the k == 1 case.
      const double x=std::fabs(k-1)*lt;
      const double r=(x>0) ? -std::expm1(-x)/x : 1.0;
      p1=lt*std::exp(-std::min(k,1.0)*lt)*r;

      // P2 is a remainder; rounding can push it marginally below zero.
      p2=std::max(0.0,1-p0-p1);

      // Move a fraction e of the 1- and 2-event mass into the 0 bin.
      p0+=e*(p1+p2);
      p1*=(1-e);
      p2*=(1-e);
    }

    double* row=out+i*stride;
    row[0]=p0;
    row[1]=p1;
    row[2]=p2;
    // This model has at most 2 events; do not leave stale values in wider rows.
    for(int j=3;j<stride;++j) row[j]=0;
  }
}
